Skip to content

Add wandb support#3053

Open
Zephyr271828 wants to merge 1 commit into
AI-Hypercomputer:mainfrom
Zephyr271828:wandb
Open

Add wandb support#3053
Zephyr271828 wants to merge 1 commit into
AI-Hypercomputer:mainfrom
Zephyr271828:wandb

Conversation

@Zephyr271828
Copy link
Copy Markdown

@Zephyr271828 Zephyr271828 commented Jan 30, 2026

Description

This PR aims to implement #2434 and add wandb logging support to MaxText.

Implementation details

The implementation of wandb logging simply follows the style of other logging interfaces.

Initialization

class MetricLogger:
  """
  Logger for saving metrics to a local file, GCS and TensorBoard.
  """

  def __init__(self, config, learning_rate_schedule):
    self.writer = max_utils.initialize_summary_writer(config.tensorboard_dir, config.run_name)
    self.config = config
    self.metadata = {}
    self.running_gcs_metrics = [] if config.gcs_metrics else None
    self.performance_metric_queue = self.get_performance_metric_queue(config)
    self.learning_rate_schedule = learning_rate_schedule
    self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
    self.buffered_train_metrics = None
    
    if self.config.managed_mldiagnostics:
      ManagedMLDiagnostics(config)  # Initialize the MLRun instance.
      
    self.enable_wandb = self.config.enable_wandb and socket.gethostname().endswith("-0") # you should only init wandb on one host.
    if self.enable_wandb: 
      wandb.init(
        project=config.wandb_project_name,
        name=config.wandb_run_name,
        resume="allow",
      ) # Initialize wandb logger.

Logging step

  def write_metrics(self, metrics, step, is_training=True):
    """Entry point for all metrics writing in Train's Main."""
    if metrics:
      self.log_metrics(metrics, step, is_training)

      if self.config.enable_tensorboard:
        self.write_metrics_to_tensorboard(metrics, step, is_training)

      if self.config.metrics_file:
        self.write_metrics_locally(metrics, step)

      if self.config.gcs_metrics and jax.process_index() == 0:
        self.write_metrics_for_gcs(metrics, step, is_training)

      if self.config.managed_mldiagnostics:
        self.write_metrics_to_managed_mldiagnostics(metrics, step)
        
      if self.enable_wandb:
        self.write_metrics_to_wandb(metrics, step)
  def write_metrics_to_wandb(self, metrics, step):
    """Write metrics to weights and biases (wandb)."""
    flat_metrics = {}
    for key, val in metrics.get("scalar", {}).items():
      flat_metrics[key] = float(val)
    for key, val in metrics.get("scalars", {}).items():
      for subkey, subval in val.items():
        flat_metrics[f"{key}/{subkey}"] = float(subval)
    wandb.log(flat_metrics, step=step)

Usage

python -u -m src.MaxText.train src/MaxText/configs/base.yml \
    ...
    enable_wandb=True \
    wandb_project_name=xxx \
    wandb_run_name=yyy \
    ...

Limitations

Currently this implementation does not support resuming from an existing wandb run. In order to resume, we need to first retrieve the run_id from somewhere, then do

wandb.init(
    project=config.wandb_project,
    name=config.wandb_run_name,
    id=run_id,
    resume="allow",
)

It makes sense to save the run_ids at some cache dir inside of the maxtext repo, but I don't know whether that's consistent with the design philosophy of this project.

Tests

Example training script:

#!/bin/bash

python tools/orchestration/multihost_runner.py \
    --TPU_PREFIX=${TPU_PREFIX} \
    --COMMAND="
    export WANDB_API_KEY=''
    export PYTHONPATH=./src:\${PYTHONPATH:-''}
    python -m src.MaxText.train src/MaxText/configs/base.yml \
        run_name=${RUN_NAME} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        dataset_type=grain \
        grain_train_files=${DATA_FILES} \
        grain_file_type='arrayrecord' \
        grain_worker_count=1 \
        enable_data_shuffling=${SHUFFLE} \
        tokenize_train_data=False \
        tokenize_eval_data=False \
        max_target_length=${SEQ_LEN} \
        async_checkpointing=${ASYNC_CHECKPOINTING} \
        model_name=${MODEL_NAME} \
        steps=${NUM_STEPS} \
        per_device_batch_size=${BATCH_SIZE} \
        gradient_accumulation_steps=${GRAD_ACCUM} \
        gradient_clipping_threshold=${GRAD_CLIP} \
        learning_rate=${LR} \
        warmup_steps_fraction=${WARMUP_RATIO} \
        checkpoint_period=500 \
        enable_wandb=True \
        wandb_project_name=${WANDB_PROJ_NAME} \
        wandb_run_name=${TPU_PREFIX}_${RUN_NAME} \
        packing=false \
    "

Outputs:

Per train step:
 Total TFLOPs: 104.77 
 split as 55.92% learnable weight flops and 44.08% attention flops
I0131 11:17:43.038159 136759033735168 max_utils.py:695] Total memory size: 6.4 GB, Output size: 0.2 GB, Temp size: 6.2 GB, Argument size: 0.2 GB, Host temp size: 0.0 GB.
wandb: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
wandb: Currently logged in as: yx3038 (yx3038-new-york-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: setting up run negt2tz0
wandb: Tracking run with wandb version 0.24.1
wandb: Run data is saved locally in /home/zephyr/2026-01-31-11-16-10/wandb/run-20260131_111743-negt2tz0
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run yufeng-v6e-32-0003_qwen3-0.6b_L200_seqlen_8192_bs_1_grad_accum_2_lr_0.0003_min_lr_ratio_0.1_warmup_ratio_0.05
wandb: ⭐️ View project at https://wandb.ai/yx3038-new-york-university/llm_pruning
wandb: 🚀 View run at https://wandb.ai/yx3038-new-york-university/llm_pruning/runs/negt2tz0
I0131 11:17:44.613295 136759033735168 metric_logger.py:297] number parameters: 0.596 billion
I0131 11:17:44.614987 136707965978176 grain_pool.py:367] Grain pool will use 1 processes.
I0131 11:17:44.618239 136707965978176 grain_pool.py:440] Grain pool will start child processes.
I0131 11:17:44.620215 136707965978176 grain_pool.py:448] Grain pool started all child processes.
2026-01-31 11:17:46.845550: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-01-31 11:17:46.845879: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-31 11:17:46.882721: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-01-31 11:17:48.106673: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-31 11:17:48.107431: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-01-31 11:17:49.626479: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0131 11:17:58.143736 136759033735168 max_utils.py:654] 
Memstats: After params initialized:
I0131 11:17:58.144038 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_10(process=3,(2,2,0,0))
I0131 11:17:58.144209 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_11(process=3,(3,2,0,0))
I0131 11:17:58.144291 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_14(process=3,(2,3,0,0))
I0131 11:17:58.144401 136759033735168 max_utils.py:660] 	Using (GB) 0.25 / 31.25 (0.800000%) on TPU_15(process=3,(3,3,0,0))
I0131 11:18:29.464875 136759033735168 metric_logger.py:193] completed step: 1, seconds: 13.530, TFLOP/s/device: 7.743, Tokens/s/device: 1210.946, total_weights: 524224, loss: 252.466
I0131 11:18:30.696314 136759033735168 metric_logger.py:193] completed step: 2, seconds: 0.320, TFLOP/s/device: 327.449, Tokens/s/device: 51208.161, total_weights: 524224, loss: 252.398
I0131 11:18:31.927749 136759033735168 metric_logger.py:193] completed step: 3, seconds: 31.010, TFLOP/s/device: 3.378, Tokens/s/device: 528.341, total_weights: 524224, loss: 251.585
I0131 11:18:33.159628 136759033735168 metric_logger.py:193] completed step: 4, seconds: 1.231, TFLOP/s/device: 85.084, Tokens/s/device: 13305.786, total_weights: 524224, loss: 250.936
I0131 11:18:34.390431 136759033735168 metric_logger.py:193] completed step: 5, seconds: 1.231, TFLOP/s/device: 85.076, Tokens/s/device: 13304.576, total_weights: 524224, loss: 248.162
I0131 11:18:35.622099 136759033735168 metric_logger.py:193] completed step: 6, seconds: 1.232, TFLOP/s/device: 85.047, Tokens/s/device: 13300.062, total_weights: 524224, loss: 245.645
I0131 11:18:36.853450 136759033735168 metric_logger.py:193] completed step: 7, seconds: 1.230, TFLOP/s/device: 85.162, Tokens/s/device: 13318.008, total_weights: 524224, loss: 243.418
I0131 11:18:38.084757 136759033735168 metric_logger.py:193] completed step: 8, seconds: 1.232, TFLOP/s/device: 85.046, Tokens/s/device: 13299.846, total_weights: 524224, loss: 241.825
I0131 11:18:39.316017 136759033735168 metric_logger.py:193] completed step: 9, seconds: 1.231, TFLOP/s/device: 85.090, Tokens/s/device: 13306.867, total_weights: 524224, loss: 237.915
I0131 11:18:40.547462 136759033735168 metric_logger.py:193] completed step: 10, seconds: 1.231, TFLOP/s/device: 85.084, Tokens/s/device: 13305.829, total_weights: 524224, loss: 233.977
...

Wandb output:
截屏2026-01-31 19 29 26

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@google-cla
Copy link
Copy Markdown

google-cla Bot commented Jan 30, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@Zephyr271828 Zephyr271828 marked this pull request as draft January 30, 2026 13:49
@Zephyr271828 Zephyr271828 marked this pull request as ready for review January 31, 2026 11:41
Comment thread src/maxtext/common/metric_logger.py Outdated
if self.config.managed_mldiagnostics:
ManagedMLDiagnostics(config) # Initialize the MLRun instance.

self.enable_wandb = self.config.enable_wandb and socket.gethostname().endswith("-0") # you should only init wandb on one host.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check jax process index instead.

Comment thread src/maxtext/common/metric_logger.py
Comment thread src/maxtext/common/metric_logger.py
@dipannita08
Copy link
Copy Markdown
Collaborator

Thanks @Zephyr271828, could you also share any pre-requisites to generate the WANDB_API_KEY?

@Zephyr271828
Copy link
Copy Markdown
Author

Thanks @Zephyr271828, could you also share any pre-requisites to generate the WANDB_API_KEY?

Thank you for your detailed review! I will modify the code accordingly.

I don't think any prerequisite is needed to generate the API key? You may simply go to wandb.ai/settings#apikeys to generate an API key. Then you only need to set export WANDB_API_KEY=<your_api_key> in your script and install wandb, then you are good to go.

@kryvokhyzha
Copy link
Copy Markdown
Contributor

@Zephyr271828 Hi!
I appreciate your PR, and I'm interested in adding wandb to maxtext. Let me know if you need any help

@Zephyr271828
Copy link
Copy Markdown
Author

Zephyr271828 commented Feb 20, 2026

@shralex @dipannita08 @gagika @kryvokhyzha Thank you for your support and detailed feedback! Below is a summary of the potential improvements you suggested:

  • use jax.process instead of socket to identify rank 0
  • check if logging wandb metrics and conversion (float) affects performance
  • log more metrics

Rank 0 detection

See here.

Performance

Due to limited TPU resources I have, I tested the performance of training qwen3-0.6b from scratch w/ and w/o wandb on v4-16 (2 tpu vms) to simulate a basic multi-host training setup.

command
#!/bin/bash
set -euo pipefail

source get_tpu_bucket_name.sh

export TPU_PREFIX="$(get_tpu_name)"
export BUCKET_NAME="$(get_bucket_name)"
export NUM_HOSTS=$(get_num_hosts)

for arg in "$@"; do
    case $arg in
        --lr=*) LR="${arg#*=}" ;;
        --batch_size=*) BATCH_SIZE="${arg#*=}" ;;
        --global_batch_size=*) GLOBAL_BATCH_SIZE="${arg#*=}" ;;
        --grad_clip=*) GRAD_CLIP="${arg#*=}" ;;
        --min_lr_ratio=*) MIN_LR_RATIO="${arg#*=}" ;;
        --warmup_ratio=*) WARMUP_RATIO="${arg#*=}" ;;
        --max_to_keep=*) MAX_TO_KEEP="${arg#*=}" ;;
        --data_files=*) DATA_FILES="${arg#*=}" ;;
        --shuffle=*) SHUFFLE="${arg#*=}" ;;
        --tag=*) TAG="${arg#*=}" ;;
        *) echo "[WARN] Unknown arg $arg" ;;
    esac
done

export MODEL_NAME="qwen3-0.6b"
export NUM_STEPS=50000
export SEQ_LEN=8192
export BATCH_SIZE=${BATCH_SIZE:-1}
export GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-64}
export GRAD_ACCUM=$((GLOBAL_BATCH_SIZE / BATCH_SIZE / NUM_HOSTS / 4))
export GRAD_CLIP=${GRAD_CLIP:-1.0}
export LR=${LR:-0.0003}
export MIN_LR_RATIO=${MIN_LR_RATIO:-0.1}
export WARMUP_RATIO=${WARMUP_RATIO:-0.05}
export ASYNC_CHECKPOINTING=false
export BASE_OUTPUT_DIRECTORY="gs://${BUCKET_NAME}/model_ckpts/maxtext"
export MAX_TO_KEEP=${MAX_TO_KEEP:-1}
export DATA_FILES="${DATA_FILES:-/home/zephyr/gcs-bucket/datasets/dclm/llama3_array_record_with_special_tokens_64/*.array_record}"
export SHUFFLE="${SHUFFLE:-True}"
export RUN_NAME="${MODEL_NAME}_L200_seqlen_${SEQ_LEN}_bs_${BATCH_SIZE}_grad_accum_${GRAD_ACCUM}_lr_${LR}_min_lr_ratio_${MIN_LR_RATIO}_warmup_ratio_${WARMUP_RATIO}"
if [ ! -z "${TAG:-}" ]; then
    export RUN_NAME="${RUN_NAME}_${TAG}"
fi
export JAX_PLATFORMS=tpu
export SPARSE_MODEL_TRAINING=False

export PYTHONPATH=./src:${PYTHONPATH:-''}
python -u multihost_runner_orig.py \
    --TPU_PREFIX=${TPU_PREFIX} \
    --COMMAND="
    export TPU_LOG_DIR=/home/zephyr/tpu_logs
    export WANDB_API_KEY='7d11bbca76b3081b6bd1efbbcf1572aab26c5d56'
    source ~/maxtext_env_py311/bin/activate
    export PYTHONPATH=./src:\${PYTHONPATH:-''}
    ~/maxtext_env_py311/bin/python -u -m src.MaxText.train src/MaxText/configs/base.yml \
        run_name=${RUN_NAME} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        dataset_type=grain \
        grain_train_files=${DATA_FILES} \
        grain_file_type='arrayrecord' \
        grain_worker_count=1 \
        enable_data_shuffling=${SHUFFLE} \
        tokenize_train_data=False \
        tokenize_eval_data=False \
        max_target_length=${SEQ_LEN} \
        async_checkpointing=${ASYNC_CHECKPOINTING} \
        model_name=${MODEL_NAME} \
        steps=${NUM_STEPS} \
        per_device_batch_size=${BATCH_SIZE} \
        gradient_accumulation_steps=${GRAD_ACCUM} \
        gradient_clipping_threshold=${GRAD_CLIP} \
        learning_rate=${LR} \
        warmup_steps_fraction=${WARMUP_RATIO} \
        checkpoint_period=500 \
        enable_wandb=True \
        wandb_project_name=llm_pruning \
        wandb_run_name=${TPU_PREFIX}_${RUN_NAME} \
        packing=false \
        sharding_tolerance=0.5 \
    "
w/o wandb logs
I0220 17:10:13.291728 139933146753024 max_utils.py:695] Total memory size: 17.8 GB, Output size: 6.7 GB, Temp size: 11.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.
Per train step:
 Total TFLOPs: 419.07 
 split as 55.92% learnable weight flops and 44.08% attention flops
I0220 17:10:13.300330 139933146753024 metric_logger.py:298] number parameters: 0.596 billion
I0220 17:10:13.362068 139847774729792 grain_pool.py:367] Grain pool will use 1 processes.
I0220 17:10:13.366581 139847774729792 grain_pool.py:440] Grain pool will start child processes.
I0220 17:10:13.369205 139847774729792 grain_pool.py:448] Grain pool started all child processes.
2026-02-20 17:10:16.334197: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-20 17:10:16.379936: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-20 17:10:17.931806: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-20 17:10:19.885891: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0220 17:10:30.075978 139933146753024 max_utils.py:654] 
Memstats: After params initialized:
I0220 17:10:30.076180 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_0(process=0,(0,0,0,0))
I0220 17:10:30.076251 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_1(process=0,(1,0,0,0))
I0220 17:10:30.076312 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_2(process=0,(0,1,0,0))
I0220 17:10:30.076370 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_3(process=0,(1,1,0,0))
I0220 17:10:44.227162 139933146753024 metric_logger.py:194] completed step: 1, seconds: 16.716, TFLOP/s/device: 25.070, Tokens/s/device: 3920.633, total_weights: 524224, loss: 249.849
I0220 17:10:50.938067 139933146753024 metric_logger.py:194] completed step: 2, seconds: 0.478, TFLOP/s/device: 876.634, Tokens/s/device: 137092.270, total_weights: 524224, loss: 250.093
I0220 17:10:57.648830 139933146753024 metric_logger.py:194] completed step: 3, seconds: 13.688, TFLOP/s/device: 30.616, Tokens/s/device: 4787.852, total_weights: 524224, loss: 249.443
I0220 17:11:04.359475 139933146753024 metric_logger.py:194] completed step: 4, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.214, total_weights: 524224, loss: 247.791
I0220 17:11:11.070475 139933146753024 metric_logger.py:194] completed step: 5, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.462, total_weights: 524224, loss: 245.704
I0220 17:11:17.781303 139933146753024 metric_logger.py:194] completed step: 6, seconds: 6.713, TFLOP/s/device: 62.423, Tokens/s/device: 9761.967, total_weights: 524224, loss: 242.812
I0220 17:11:24.492303 139933146753024 metric_logger.py:194] completed step: 7, seconds: 6.709, TFLOP/s/device: 62.468, Tokens/s/device: 9769.032, total_weights: 524224, loss: 242.086
I0220 17:11:31.203217 139933146753024 metric_logger.py:194] completed step: 8, seconds: 6.711, TFLOP/s/device: 62.449, Tokens/s/device: 9766.048, total_weights: 524224, loss: 238.795
I0220 17:11:37.914210 139933146753024 metric_logger.py:194] completed step: 9, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.406, total_weights: 524224, loss: 235.700
I0220 17:11:44.624889 139933146753024 metric_logger.py:194] completed step: 10, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.869, total_weights: 524224, loss: 230.782
I0220 17:11:51.335738 139933146753024 metric_logger.py:194] completed step: 11, seconds: 6.711, TFLOP/s/device: 62.442, Tokens/s/device: 9765.028, total_weights: 524224, loss: 228.549
I0220 17:11:58.046651 139933146753024 metric_logger.py:194] completed step: 12, seconds: 6.710, TFLOP/s/device: 62.455, Tokens/s/device: 9766.943, total_weights: 524224, loss: 225.022
I0220 17:12:04.757413 139933146753024 metric_logger.py:194] completed step: 13, seconds: 6.711, TFLOP/s/device: 62.447, Tokens/s/device: 9765.717, total_weights: 524224, loss: 218.865
I0220 17:12:11.468343 139933146753024 metric_logger.py:194] completed step: 14, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.496, total_weights: 524224, loss: 214.666
I0220 17:12:18.179343 139933146753024 metric_logger.py:194] completed step: 15, seconds: 6.710, TFLOP/s/device: 62.450, Tokens/s/device: 9766.202, total_weights: 524224, loss: 208.936
I0220 17:12:24.890058 139933146753024 metric_logger.py:194] completed step: 16, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.237, total_weights: 524224, loss: 206.809
I0220 17:12:31.600906 139933146753024 metric_logger.py:194] completed step: 17, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.589, total_weights: 524224, loss: 201.072
I0220 17:12:38.311885 139933146753024 metric_logger.py:194] completed step: 18, seconds: 6.711, TFLOP/s/device: 62.449, Tokens/s/device: 9766.096, total_weights: 524224, loss: 195.263
I0220 17:12:45.023010 139933146753024 metric_logger.py:194] completed step: 19, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.346, total_weights: 524224, loss: 191.112
I0220 17:12:51.733779 139933146753024 metric_logger.py:194] completed step: 20, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.543, total_weights: 524224, loss: 185.623
I0220 17:12:58.444595 139933146753024 metric_logger.py:194] completed step: 21, seconds: 6.712, TFLOP/s/device: 62.439, Tokens/s/device: 9764.521, total_weights: 524224, loss: 180.439
I0220 17:13:05.155681 139933146753024 metric_logger.py:194] completed step: 22, seconds: 6.710, TFLOP/s/device: 62.450, Tokens/s/device: 9766.285, total_weights: 524224, loss: 178.119
I0220 17:13:11.866602 139933146753024 metric_logger.py:194] completed step: 23, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.330, total_weights: 524224, loss: 170.033
I0220 17:13:18.577209 139933146753024 metric_logger.py:194] completed step: 24, seconds: 6.711, TFLOP/s/device: 62.448, Tokens/s/device: 9765.981, total_weights: 524224, loss: 162.621
I0220 17:13:25.288255 139933146753024 metric_logger.py:194] completed step: 25, seconds: 6.713, TFLOP/s/device: 62.430, Tokens/s/device: 9763.065, total_weights: 524224, loss: 159.374
I0220 17:13:31.999057 139933146753024 metric_logger.py:194] completed step: 26, seconds: 6.709, TFLOP/s/device: 62.466, Tokens/s/device: 9768.813, total_weights: 524224, loss: 154.217
I0220 17:13:38.710025 139933146753024 metric_logger.py:194] completed step: 27, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.294, total_weights: 524224, loss: 148.877
I0220 17:13:45.420726 139933146753024 metric_logger.py:194] completed step: 28, seconds: 6.711, TFLOP/s/device: 62.447, Tokens/s/device: 9765.813, total_weights: 524224, loss: 144.224
I0220 17:13:52.131641 139933146753024 metric_logger.py:194] completed step: 29, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.349, total_weights: 524224, loss: 139.123
I0220 17:13:58.842322 139933146753024 metric_logger.py:194] completed step: 30, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.543, total_weights: 524224, loss: 135.702
w/ wandb logs
I0220 17:16:03.667288 140548037974016 max_utils.py:695] Total memory size: 17.8 GB, Output size: 6.7 GB, Temp size: 11.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.
wandb: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
wandb: Currently logged in as: yx3038 (yx3038-new-york-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.25.0
wandb: Run data is saved locally in /home/zephyr/2026-02-20-17-14-43/wandb/run-20260220_171603-k0646bs7
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run yufeng-qw-v4-16_2_qwen3-0.6b_L200_seqlen_8192_bs_1_grad_accum_8_lr_0.0003_min_lr_ratio_0.1_warmup_ratio_0.05
wandb: ⭐️ View project at https://wandb.ai/yx3038-new-york-university/llm_pruning
wandb: 🚀 View run at https://wandb.ai/yx3038-new-york-university/llm_pruning/runs/k0646bs7
Per train step:
 Total TFLOPs: 419.07 
 split as 55.92% learnable weight flops and 44.08% attention flops
I0220 17:16:04.903574 140548037974016 metric_logger.py:298] number parameters: 0.596 billion
I0220 17:16:04.963233 140462667576896 grain_pool.py:367] Grain pool will use 1 processes.
I0220 17:16:04.969832 140462667576896 grain_pool.py:440] Grain pool will start child processes.
I0220 17:16:04.972459 140462667576896 grain_pool.py:448] Grain pool started all child processes.
2026-02-20 17:16:08.003286: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-20 17:16:08.051091: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-20 17:16:09.607792: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-20 17:16:11.587793: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0220 17:16:21.990377 140548037974016 max_utils.py:654] 
Memstats: After params initialized:
I0220 17:16:21.990737 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_0(process=0,(0,0,0,0))
I0220 17:16:21.991039 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_1(process=0,(1,0,0,0))
I0220 17:16:21.991271 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_2(process=0,(0,1,0,0))
I0220 17:16:21.991589 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_3(process=0,(1,1,0,0))
I0220 17:16:34.658525 140548037974016 metric_logger.py:194] completed step: 1, seconds: 17.029, TFLOP/s/device: 24.609, Tokens/s/device: 3848.473, total_weights: 524224, loss: 249.849
I0220 17:16:41.369417 140548037974016 metric_logger.py:194] completed step: 2, seconds: 0.497, TFLOP/s/device: 843.702, Tokens/s/device: 131942.291, total_weights: 524224, loss: 250.093
I0220 17:16:48.080118 140548037974016 metric_logger.py:194] completed step: 3, seconds: 12.190, TFLOP/s/device: 34.378, Tokens/s/device: 5376.157, total_weights: 524224, loss: 249.443
I0220 17:16:54.790829 140548037974016 metric_logger.py:194] completed step: 4, seconds: 6.709, TFLOP/s/device: 62.462, Tokens/s/device: 9768.093, total_weights: 524224, loss: 247.791
I0220 17:17:01.501564 140548037974016 metric_logger.py:194] completed step: 5, seconds: 6.710, TFLOP/s/device: 62.458, Tokens/s/device: 9767.429, total_weights: 524224, loss: 245.704
I0220 17:17:08.212548 140548037974016 metric_logger.py:194] completed step: 6, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.154, total_weights: 524224, loss: 242.812
I0220 17:17:14.923486 140548037974016 metric_logger.py:194] completed step: 7, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.342, total_weights: 524224, loss: 242.086
I0220 17:17:21.636461 140548037974016 metric_logger.py:194] completed step: 8, seconds: 6.714, TFLOP/s/device: 62.421, Tokens/s/device: 9761.641, total_weights: 524224, loss: 238.795
I0220 17:17:28.345310 140548037974016 metric_logger.py:194] completed step: 9, seconds: 6.709, TFLOP/s/device: 62.468, Tokens/s/device: 9769.004, total_weights: 524224, loss: 235.700
I0220 17:17:35.056111 140548037974016 metric_logger.py:194] completed step: 10, seconds: 6.716, TFLOP/s/device: 62.401, Tokens/s/device: 9758.650, total_weights: 524224, loss: 230.782
I0220 17:17:41.766973 140548037974016 metric_logger.py:194] completed step: 11, seconds: 6.706, TFLOP/s/device: 62.489, Tokens/s/device: 9772.269, total_weights: 524224, loss: 228.549
I0220 17:17:48.477904 140548037974016 metric_logger.py:194] completed step: 12, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.521, total_weights: 524224, loss: 225.022
I0220 17:17:55.188795 140548037974016 metric_logger.py:194] completed step: 13, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.652, total_weights: 524224, loss: 218.865
I0220 17:18:01.899731 140548037974016 metric_logger.py:194] completed step: 14, seconds: 6.710, TFLOP/s/device: 62.453, Tokens/s/device: 9766.703, total_weights: 524224, loss: 214.666
I0220 17:18:08.610642 140548037974016 metric_logger.py:194] completed step: 15, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.824, total_weights: 524224, loss: 208.936
I0220 17:18:15.321358 140548037974016 metric_logger.py:194] completed step: 16, seconds: 6.712, TFLOP/s/device: 62.436, Tokens/s/device: 9764.086, total_weights: 524224, loss: 206.809
I0220 17:18:22.032296 140548037974016 metric_logger.py:194] completed step: 17, seconds: 6.712, TFLOP/s/device: 62.440, Tokens/s/device: 9764.665, total_weights: 524224, loss: 201.072
I0220 17:18:28.743230 140548037974016 metric_logger.py:194] completed step: 18, seconds: 6.708, TFLOP/s/device: 62.470, Tokens/s/device: 9769.367, total_weights: 524224, loss: 195.263
I0220 17:18:35.454278 140548037974016 metric_logger.py:194] completed step: 19, seconds: 6.712, TFLOP/s/device: 62.440, Tokens/s/device: 9764.703, total_weights: 524224, loss: 191.112
I0220 17:18:42.165038 140548037974016 metric_logger.py:194] completed step: 20, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.207, total_weights: 524224, loss: 185.623
I0220 17:18:48.876030 140548037974016 metric_logger.py:194] completed step: 21, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.468, total_weights: 524224, loss: 180.439
I0220 17:18:55.586965 140548037974016 metric_logger.py:194] completed step: 22, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.811, total_weights: 524224, loss: 178.119
I0220 17:19:02.298099 140548037974016 metric_logger.py:194] completed step: 23, seconds: 6.715, TFLOP/s/device: 62.405, Tokens/s/device: 9759.255, total_weights: 524224, loss: 170.033
I0220 17:19:09.008782 140548037974016 metric_logger.py:194] completed step: 24, seconds: 6.706, TFLOP/s/device: 62.494, Tokens/s/device: 9773.089, total_weights: 524224, loss: 162.621
I0220 17:19:15.719572 140548037974016 metric_logger.py:194] completed step: 25, seconds: 6.717, TFLOP/s/device: 62.390, Tokens/s/device: 9756.924, total_weights: 524224, loss: 159.374
I0220 17:19:22.430507 140548037974016 metric_logger.py:194] completed step: 26, seconds: 6.706, TFLOP/s/device: 62.491, Tokens/s/device: 9772.699, total_weights: 524224, loss: 154.217
I0220 17:19:29.141228 140548037974016 metric_logger.py:194] completed step: 27, seconds: 6.710, TFLOP/s/device: 62.457, Tokens/s/device: 9767.282, total_weights: 524224, loss: 148.877
I0220 17:19:35.852080 140548037974016 metric_logger.py:194] completed step: 28, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.672, total_weights: 524224, loss: 144.224
I0220 17:19:42.562924 140548037974016 metric_logger.py:194] completed step: 29, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.540, total_weights: 524224, loss: 139.123
I0220 17:19:49.273759 140548037974016 metric_logger.py:194] completed step: 30, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.534, total_weights: 524224, loss: 135.702
conclusion: we can see from the logs that the average training time per step is ~ 6.71s w/ and w/o wandb. Therefore, wandb hardly affects the training performance in my setups. Please let me know if you think testing with more setups or more tpus is necessary.

Logging more metrics

The metrics I log closely follows the implementation for tensorboard logging (i.e., I'm logging exactly the same metrics as tensorboard logging). Below is a list of the currently supported metrics:

  • learning/total_weights
  • learning/raw_grad_norm
  • learning/param_norm
  • learning/mtp_loss
  • learning/moe_lb_loss
  • learning/loss
  • learning/grad_norm
  • learning/current_learning_rate
  • perf/step_time_seconds
  • perf/per_device_tokens_per_sec
  • perf/per_device_tokens
  • perf/per_device_tflops_per_sec
  • perf/per_device_tflops
  • Network Traffic (Bytes)
  • Disk Utilization (GB)
  • Disk Utilization (%)
  • Process CPU Threads in Use
  • Process Memory Available (MB)
  • Process Memory in Use (%)
  • Process Memory in Use (MB)
  • System Memory Utilization (%)
  • TPU Memory Usage (Bytes)
  • TPU Memory Usage (%)
  • TPU Duty Cycle (%)
  • Process CPU Utilization (%)

Please let me know if any other metrics should be added to this list. Personally I think we may add RL metrics in subsequent PRs.

@Zephyr271828
Copy link
Copy Markdown
Author

@shralex @dipannita08 @gagika @kryvokhyzha Thank you for your support and detailed feedback! Below is a summary of the potential improvements you suggested:

  • use jax.process instead of socket to identify rank 0
  • check if logging wandb metrics and conversion (float) affects performance
  • log more metrics

Rank 0 detection

See here.

Performance

Due to limited TPU resources I have, I tested the performance of training qwen3-0.6b from scratch w/ and w/o wandb on v4-16 (2 tpu vms) to simulate a basic multi-host training setup.

command

#!/bin/bash
set -euo pipefail

source get_tpu_bucket_name.sh

export TPU_PREFIX="$(get_tpu_name)"
export BUCKET_NAME="$(get_bucket_name)"
export NUM_HOSTS=$(get_num_hosts)

for arg in "$@"; do
    case $arg in
        --lr=*) LR="${arg#*=}" ;;
        --batch_size=*) BATCH_SIZE="${arg#*=}" ;;
        --global_batch_size=*) GLOBAL_BATCH_SIZE="${arg#*=}" ;;
        --grad_clip=*) GRAD_CLIP="${arg#*=}" ;;
        --min_lr_ratio=*) MIN_LR_RATIO="${arg#*=}" ;;
        --warmup_ratio=*) WARMUP_RATIO="${arg#*=}" ;;
        --max_to_keep=*) MAX_TO_KEEP="${arg#*=}" ;;
        --data_files=*) DATA_FILES="${arg#*=}" ;;
        --shuffle=*) SHUFFLE="${arg#*=}" ;;
        --tag=*) TAG="${arg#*=}" ;;
        *) echo "[WARN] Unknown arg $arg" ;;
    esac
done

export MODEL_NAME="qwen3-0.6b"
export NUM_STEPS=50000
export SEQ_LEN=8192
export BATCH_SIZE=${BATCH_SIZE:-1}
export GLOBAL_BATCH_SIZE=${GLOBAL_BATCH_SIZE:-64}
export GRAD_ACCUM=$((GLOBAL_BATCH_SIZE / BATCH_SIZE / NUM_HOSTS / 4))
export GRAD_CLIP=${GRAD_CLIP:-1.0}
export LR=${LR:-0.0003}
export MIN_LR_RATIO=${MIN_LR_RATIO:-0.1}
export WARMUP_RATIO=${WARMUP_RATIO:-0.05}
export ASYNC_CHECKPOINTING=false
export BASE_OUTPUT_DIRECTORY="gs://${BUCKET_NAME}/model_ckpts/maxtext"
export MAX_TO_KEEP=${MAX_TO_KEEP:-1}
export DATA_FILES="${DATA_FILES:-/home/zephyr/gcs-bucket/datasets/dclm/llama3_array_record_with_special_tokens_64/*.array_record}"
export SHUFFLE="${SHUFFLE:-True}"
export RUN_NAME="${MODEL_NAME}_L200_seqlen_${SEQ_LEN}_bs_${BATCH_SIZE}_grad_accum_${GRAD_ACCUM}_lr_${LR}_min_lr_ratio_${MIN_LR_RATIO}_warmup_ratio_${WARMUP_RATIO}"
if [ ! -z "${TAG:-}" ]; then
    export RUN_NAME="${RUN_NAME}_${TAG}"
fi
export JAX_PLATFORMS=tpu
export SPARSE_MODEL_TRAINING=False

export PYTHONPATH=./src:${PYTHONPATH:-''}
python -u multihost_runner_orig.py \
    --TPU_PREFIX=${TPU_PREFIX} \
    --COMMAND="
    export TPU_LOG_DIR=/home/zephyr/tpu_logs
    export WANDB_API_KEY='7d11bbca76b3081b6bd1efbbcf1572aab26c5d56'
    source ~/maxtext_env_py311/bin/activate
    export PYTHONPATH=./src:\${PYTHONPATH:-''}
    ~/maxtext_env_py311/bin/python -u -m src.MaxText.train src/MaxText/configs/base.yml \
        run_name=${RUN_NAME} \
        base_output_directory=${BASE_OUTPUT_DIRECTORY} \
        dataset_type=grain \
        grain_train_files=${DATA_FILES} \
        grain_file_type='arrayrecord' \
        grain_worker_count=1 \
        enable_data_shuffling=${SHUFFLE} \
        tokenize_train_data=False \
        tokenize_eval_data=False \
        max_target_length=${SEQ_LEN} \
        async_checkpointing=${ASYNC_CHECKPOINTING} \
        model_name=${MODEL_NAME} \
        steps=${NUM_STEPS} \
        per_device_batch_size=${BATCH_SIZE} \
        gradient_accumulation_steps=${GRAD_ACCUM} \
        gradient_clipping_threshold=${GRAD_CLIP} \
        learning_rate=${LR} \
        warmup_steps_fraction=${WARMUP_RATIO} \
        checkpoint_period=500 \
        enable_wandb=True \
        wandb_project_name=llm_pruning \
        wandb_run_name=${TPU_PREFIX}_${RUN_NAME} \
        packing=false \
        sharding_tolerance=0.5 \
    "

w/o wandb logs

I0220 17:10:13.291728 139933146753024 max_utils.py:695] Total memory size: 17.8 GB, Output size: 6.7 GB, Temp size: 11.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.
Per train step:
 Total TFLOPs: 419.07 
 split as 55.92% learnable weight flops and 44.08% attention flops
I0220 17:10:13.300330 139933146753024 metric_logger.py:298] number parameters: 0.596 billion
I0220 17:10:13.362068 139847774729792 grain_pool.py:367] Grain pool will use 1 processes.
I0220 17:10:13.366581 139847774729792 grain_pool.py:440] Grain pool will start child processes.
I0220 17:10:13.369205 139847774729792 grain_pool.py:448] Grain pool started all child processes.
2026-02-20 17:10:16.334197: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-20 17:10:16.379936: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-20 17:10:17.931806: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-20 17:10:19.885891: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0220 17:10:30.075978 139933146753024 max_utils.py:654] 
Memstats: After params initialized:
I0220 17:10:30.076180 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_0(process=0,(0,0,0,0))
I0220 17:10:30.076251 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_1(process=0,(1,0,0,0))
I0220 17:10:30.076312 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_2(process=0,(0,1,0,0))
I0220 17:10:30.076370 139933146753024 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_3(process=0,(1,1,0,0))
I0220 17:10:44.227162 139933146753024 metric_logger.py:194] completed step: 1, seconds: 16.716, TFLOP/s/device: 25.070, Tokens/s/device: 3920.633, total_weights: 524224, loss: 249.849
I0220 17:10:50.938067 139933146753024 metric_logger.py:194] completed step: 2, seconds: 0.478, TFLOP/s/device: 876.634, Tokens/s/device: 137092.270, total_weights: 524224, loss: 250.093
I0220 17:10:57.648830 139933146753024 metric_logger.py:194] completed step: 3, seconds: 13.688, TFLOP/s/device: 30.616, Tokens/s/device: 4787.852, total_weights: 524224, loss: 249.443
I0220 17:11:04.359475 139933146753024 metric_logger.py:194] completed step: 4, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.214, total_weights: 524224, loss: 247.791
I0220 17:11:11.070475 139933146753024 metric_logger.py:194] completed step: 5, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.462, total_weights: 524224, loss: 245.704
I0220 17:11:17.781303 139933146753024 metric_logger.py:194] completed step: 6, seconds: 6.713, TFLOP/s/device: 62.423, Tokens/s/device: 9761.967, total_weights: 524224, loss: 242.812
I0220 17:11:24.492303 139933146753024 metric_logger.py:194] completed step: 7, seconds: 6.709, TFLOP/s/device: 62.468, Tokens/s/device: 9769.032, total_weights: 524224, loss: 242.086
I0220 17:11:31.203217 139933146753024 metric_logger.py:194] completed step: 8, seconds: 6.711, TFLOP/s/device: 62.449, Tokens/s/device: 9766.048, total_weights: 524224, loss: 238.795
I0220 17:11:37.914210 139933146753024 metric_logger.py:194] completed step: 9, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.406, total_weights: 524224, loss: 235.700
I0220 17:11:44.624889 139933146753024 metric_logger.py:194] completed step: 10, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.869, total_weights: 524224, loss: 230.782
I0220 17:11:51.335738 139933146753024 metric_logger.py:194] completed step: 11, seconds: 6.711, TFLOP/s/device: 62.442, Tokens/s/device: 9765.028, total_weights: 524224, loss: 228.549
I0220 17:11:58.046651 139933146753024 metric_logger.py:194] completed step: 12, seconds: 6.710, TFLOP/s/device: 62.455, Tokens/s/device: 9766.943, total_weights: 524224, loss: 225.022
I0220 17:12:04.757413 139933146753024 metric_logger.py:194] completed step: 13, seconds: 6.711, TFLOP/s/device: 62.447, Tokens/s/device: 9765.717, total_weights: 524224, loss: 218.865
I0220 17:12:11.468343 139933146753024 metric_logger.py:194] completed step: 14, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.496, total_weights: 524224, loss: 214.666
I0220 17:12:18.179343 139933146753024 metric_logger.py:194] completed step: 15, seconds: 6.710, TFLOP/s/device: 62.450, Tokens/s/device: 9766.202, total_weights: 524224, loss: 208.936
I0220 17:12:24.890058 139933146753024 metric_logger.py:194] completed step: 16, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.237, total_weights: 524224, loss: 206.809
I0220 17:12:31.600906 139933146753024 metric_logger.py:194] completed step: 17, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.589, total_weights: 524224, loss: 201.072
I0220 17:12:38.311885 139933146753024 metric_logger.py:194] completed step: 18, seconds: 6.711, TFLOP/s/device: 62.449, Tokens/s/device: 9766.096, total_weights: 524224, loss: 195.263
I0220 17:12:45.023010 139933146753024 metric_logger.py:194] completed step: 19, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.346, total_weights: 524224, loss: 191.112
I0220 17:12:51.733779 139933146753024 metric_logger.py:194] completed step: 20, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.543, total_weights: 524224, loss: 185.623
I0220 17:12:58.444595 139933146753024 metric_logger.py:194] completed step: 21, seconds: 6.712, TFLOP/s/device: 62.439, Tokens/s/device: 9764.521, total_weights: 524224, loss: 180.439
I0220 17:13:05.155681 139933146753024 metric_logger.py:194] completed step: 22, seconds: 6.710, TFLOP/s/device: 62.450, Tokens/s/device: 9766.285, total_weights: 524224, loss: 178.119
I0220 17:13:11.866602 139933146753024 metric_logger.py:194] completed step: 23, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.330, total_weights: 524224, loss: 170.033
I0220 17:13:18.577209 139933146753024 metric_logger.py:194] completed step: 24, seconds: 6.711, TFLOP/s/device: 62.448, Tokens/s/device: 9765.981, total_weights: 524224, loss: 162.621
I0220 17:13:25.288255 139933146753024 metric_logger.py:194] completed step: 25, seconds: 6.713, TFLOP/s/device: 62.430, Tokens/s/device: 9763.065, total_weights: 524224, loss: 159.374
I0220 17:13:31.999057 139933146753024 metric_logger.py:194] completed step: 26, seconds: 6.709, TFLOP/s/device: 62.466, Tokens/s/device: 9768.813, total_weights: 524224, loss: 154.217
I0220 17:13:38.710025 139933146753024 metric_logger.py:194] completed step: 27, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.294, total_weights: 524224, loss: 148.877
I0220 17:13:45.420726 139933146753024 metric_logger.py:194] completed step: 28, seconds: 6.711, TFLOP/s/device: 62.447, Tokens/s/device: 9765.813, total_weights: 524224, loss: 144.224
I0220 17:13:52.131641 139933146753024 metric_logger.py:194] completed step: 29, seconds: 6.711, TFLOP/s/device: 62.444, Tokens/s/device: 9765.349, total_weights: 524224, loss: 139.123
I0220 17:13:58.842322 139933146753024 metric_logger.py:194] completed step: 30, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.543, total_weights: 524224, loss: 135.702

w/ wandb logs

I0220 17:16:03.667288 140548037974016 max_utils.py:695] Total memory size: 17.8 GB, Output size: 6.7 GB, Temp size: 11.1 GB, Argument size: 6.7 GB, Host temp size: 0.0 GB.
wandb: [wandb.login()] Loaded credentials for https://api.wandb.ai from WANDB_API_KEY.
wandb: Currently logged in as: yx3038 (yx3038-new-york-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.25.0
wandb: Run data is saved locally in /home/zephyr/2026-02-20-17-14-43/wandb/run-20260220_171603-k0646bs7
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run yufeng-qw-v4-16_2_qwen3-0.6b_L200_seqlen_8192_bs_1_grad_accum_8_lr_0.0003_min_lr_ratio_0.1_warmup_ratio_0.05
wandb: ⭐️ View project at https://wandb.ai/yx3038-new-york-university/llm_pruning
wandb: 🚀 View run at https://wandb.ai/yx3038-new-york-university/llm_pruning/runs/k0646bs7
Per train step:
 Total TFLOPs: 419.07 
 split as 55.92% learnable weight flops and 44.08% attention flops
I0220 17:16:04.903574 140548037974016 metric_logger.py:298] number parameters: 0.596 billion
I0220 17:16:04.963233 140462667576896 grain_pool.py:367] Grain pool will use 1 processes.
I0220 17:16:04.969832 140462667576896 grain_pool.py:440] Grain pool will start child processes.
I0220 17:16:04.972459 140462667576896 grain_pool.py:448] Grain pool started all child processes.
2026-02-20 17:16:08.003286: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2026-02-20 17:16:08.051091: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2026-02-20 17:16:09.607792: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
PyTorch was not found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
2026-02-20 17:16:11.587793: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
I0220 17:16:21.990377 140548037974016 max_utils.py:654] 
Memstats: After params initialized:
I0220 17:16:21.990737 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_0(process=0,(0,0,0,0))
I0220 17:16:21.991039 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_1(process=0,(1,0,0,0))
I0220 17:16:21.991271 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_2(process=0,(0,1,0,0))
I0220 17:16:21.991589 140548037974016 max_utils.py:660]         Using (GB) 6.7 / 30.75 (21.788618%) on TPU_3(process=0,(1,1,0,0))
I0220 17:16:34.658525 140548037974016 metric_logger.py:194] completed step: 1, seconds: 17.029, TFLOP/s/device: 24.609, Tokens/s/device: 3848.473, total_weights: 524224, loss: 249.849
I0220 17:16:41.369417 140548037974016 metric_logger.py:194] completed step: 2, seconds: 0.497, TFLOP/s/device: 843.702, Tokens/s/device: 131942.291, total_weights: 524224, loss: 250.093
I0220 17:16:48.080118 140548037974016 metric_logger.py:194] completed step: 3, seconds: 12.190, TFLOP/s/device: 34.378, Tokens/s/device: 5376.157, total_weights: 524224, loss: 249.443
I0220 17:16:54.790829 140548037974016 metric_logger.py:194] completed step: 4, seconds: 6.709, TFLOP/s/device: 62.462, Tokens/s/device: 9768.093, total_weights: 524224, loss: 247.791
I0220 17:17:01.501564 140548037974016 metric_logger.py:194] completed step: 5, seconds: 6.710, TFLOP/s/device: 62.458, Tokens/s/device: 9767.429, total_weights: 524224, loss: 245.704
I0220 17:17:08.212548 140548037974016 metric_logger.py:194] completed step: 6, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.154, total_weights: 524224, loss: 242.812
I0220 17:17:14.923486 140548037974016 metric_logger.py:194] completed step: 7, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.342, total_weights: 524224, loss: 242.086
I0220 17:17:21.636461 140548037974016 metric_logger.py:194] completed step: 8, seconds: 6.714, TFLOP/s/device: 62.421, Tokens/s/device: 9761.641, total_weights: 524224, loss: 238.795
I0220 17:17:28.345310 140548037974016 metric_logger.py:194] completed step: 9, seconds: 6.709, TFLOP/s/device: 62.468, Tokens/s/device: 9769.004, total_weights: 524224, loss: 235.700
I0220 17:17:35.056111 140548037974016 metric_logger.py:194] completed step: 10, seconds: 6.716, TFLOP/s/device: 62.401, Tokens/s/device: 9758.650, total_weights: 524224, loss: 230.782
I0220 17:17:41.766973 140548037974016 metric_logger.py:194] completed step: 11, seconds: 6.706, TFLOP/s/device: 62.489, Tokens/s/device: 9772.269, total_weights: 524224, loss: 228.549
I0220 17:17:48.477904 140548037974016 metric_logger.py:194] completed step: 12, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.521, total_weights: 524224, loss: 225.022
I0220 17:17:55.188795 140548037974016 metric_logger.py:194] completed step: 13, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.652, total_weights: 524224, loss: 218.865
I0220 17:18:01.899731 140548037974016 metric_logger.py:194] completed step: 14, seconds: 6.710, TFLOP/s/device: 62.453, Tokens/s/device: 9766.703, total_weights: 524224, loss: 214.666
I0220 17:18:08.610642 140548037974016 metric_logger.py:194] completed step: 15, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.824, total_weights: 524224, loss: 208.936
I0220 17:18:15.321358 140548037974016 metric_logger.py:194] completed step: 16, seconds: 6.712, TFLOP/s/device: 62.436, Tokens/s/device: 9764.086, total_weights: 524224, loss: 206.809
I0220 17:18:22.032296 140548037974016 metric_logger.py:194] completed step: 17, seconds: 6.712, TFLOP/s/device: 62.440, Tokens/s/device: 9764.665, total_weights: 524224, loss: 201.072
I0220 17:18:28.743230 140548037974016 metric_logger.py:194] completed step: 18, seconds: 6.708, TFLOP/s/device: 62.470, Tokens/s/device: 9769.367, total_weights: 524224, loss: 195.263
I0220 17:18:35.454278 140548037974016 metric_logger.py:194] completed step: 19, seconds: 6.712, TFLOP/s/device: 62.440, Tokens/s/device: 9764.703, total_weights: 524224, loss: 191.112
I0220 17:18:42.165038 140548037974016 metric_logger.py:194] completed step: 20, seconds: 6.711, TFLOP/s/device: 62.443, Tokens/s/device: 9765.207, total_weights: 524224, loss: 185.623
I0220 17:18:48.876030 140548037974016 metric_logger.py:194] completed step: 21, seconds: 6.710, TFLOP/s/device: 62.451, Tokens/s/device: 9766.468, total_weights: 524224, loss: 180.439
I0220 17:18:55.586965 140548037974016 metric_logger.py:194] completed step: 22, seconds: 6.711, TFLOP/s/device: 62.441, Tokens/s/device: 9764.811, total_weights: 524224, loss: 178.119
I0220 17:19:02.298099 140548037974016 metric_logger.py:194] completed step: 23, seconds: 6.715, TFLOP/s/device: 62.405, Tokens/s/device: 9759.255, total_weights: 524224, loss: 170.033
I0220 17:19:09.008782 140548037974016 metric_logger.py:194] completed step: 24, seconds: 6.706, TFLOP/s/device: 62.494, Tokens/s/device: 9773.089, total_weights: 524224, loss: 162.621
I0220 17:19:15.719572 140548037974016 metric_logger.py:194] completed step: 25, seconds: 6.717, TFLOP/s/device: 62.390, Tokens/s/device: 9756.924, total_weights: 524224, loss: 159.374
I0220 17:19:22.430507 140548037974016 metric_logger.py:194] completed step: 26, seconds: 6.706, TFLOP/s/device: 62.491, Tokens/s/device: 9772.699, total_weights: 524224, loss: 154.217
I0220 17:19:29.141228 140548037974016 metric_logger.py:194] completed step: 27, seconds: 6.710, TFLOP/s/device: 62.457, Tokens/s/device: 9767.282, total_weights: 524224, loss: 148.877
I0220 17:19:35.852080 140548037974016 metric_logger.py:194] completed step: 28, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.672, total_weights: 524224, loss: 144.224
I0220 17:19:42.562924 140548037974016 metric_logger.py:194] completed step: 29, seconds: 6.711, TFLOP/s/device: 62.446, Tokens/s/device: 9765.540, total_weights: 524224, loss: 139.123
I0220 17:19:49.273759 140548037974016 metric_logger.py:194] completed step: 30, seconds: 6.711, TFLOP/s/device: 62.445, Tokens/s/device: 9765.534, total_weights: 524224, loss: 135.702

conclusion: we can see from the logs that the average training time per step is ~ 6.71s w/ and w/o wandb. Therefore, wandb hardly affects the training performance in my setups. Please let me know if you think testing with more setups or more tpus is necessary.

Logging more metrics

The metrics I log closely follows the implementation for tensorboard logging (i.e., I'm logging exactly the same metrics as tensorboard logging). Below is a list of the currently supported metrics:

  • learning/total_weights
  • learning/raw_grad_norm
  • learning/param_norm
  • learning/mtp_loss
  • learning/moe_lb_loss
  • learning/loss
  • learning/grad_norm
  • learning/current_learning_rate
  • perf/step_time_seconds
  • perf/per_device_tokens_per_sec
  • perf/per_device_tokens
  • perf/per_device_tflops_per_sec
  • perf/per_device_tflops
  • Network Traffic (Bytes)
  • Disk Utilization (GB)
  • Disk Utilization (%)
  • Process CPU Threads in Use
  • Process Memory Available (MB)
  • Process Memory in Use (%)
  • Process Memory in Use (MB)
  • System Memory Utilization (%)
  • TPU Memory Usage (Bytes)
  • TPU Memory Usage (%)
  • TPU Duty Cycle (%)
  • Process CPU Utilization (%)

Please let me know if any other metrics should be added to this list. Personally I think we may add RL metrics in subsequent PRs.

Additionally, if you want to reproduce the experiments I've done, you may want to go to this commit.

@dipannita08
Copy link
Copy Markdown
Collaborator

Thank you @Zephyr271828 ! Will take another look today!

Copy link
Copy Markdown
Collaborator

@gagika gagika left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overall looks good, could you please address the comments and rebase the code.

Comment thread src/MaxText/pyconfig.py Outdated
if key not in valid_fields:
logger.warning("Ignoring invalid/unsupported field from YAML/CLI: %s", repr(key))
raise ValueError(f"{key!r} not in {", ".join(map(repr, valid_fields))}.")
raise ValueError(f"{key!r} not in {', '.join(map(repr, valid_fields))}.")
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaxText uses " " for strings

Comment thread src/maxtext/common/metric_logger.py Outdated
Comment on lines +136 to +137
if self.enable_wandb:
self.write_metrics_to_wandb(metrics, step)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you check here that jax.process_index() == 0?
e.g.

if self.enable_wandb and self.config.enable_wandb and jax.process_index() == 0:
   self.write_metrics_to_wandb(metrics, step)

Copy link
Copy Markdown
Collaborator

@dipannita08 dipannita08 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the additional testing @Zephyr271828! This looks great, could you address last couple comments and rebase?

@Zephyr271828
Copy link
Copy Markdown
Author

Thank you for the additional testing @Zephyr271828! This looks great, could you address last couple comments and rebase?

Sorry I did not see the previous comment. Will finish that by the end of this week:)

@Zephyr271828
Copy link
Copy Markdown
Author

Thank you for the additional testing @Zephyr271828! This looks great, could you address last couple comments and rebase?

Hi! @dipannita08 @gagika I just addressed the latest comments and rebased my commits. Thank you so much for your reviews!

Copy link
Copy Markdown
Collaborator

@gagika gagika left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks

@github-actions
Copy link
Copy Markdown

This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.

@github-actions github-actions Bot added the stale Automatically applied to stale PRs. label Apr 26, 2026
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 3, 2026

This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it.

@github-actions github-actions Bot closed this May 3, 2026
@dipannita08 dipannita08 reopened this May 27, 2026
@dipannita08 dipannita08 removed the stale Automatically applied to stale PRs. label May 27, 2026
@shralex
Copy link
Copy Markdown
Collaborator

shralex commented May 27, 2026

@Zephyr271828 it seems that this was automatically closed instead of merging :( could you please rebase ?

@Zephyr271828
Copy link
Copy Markdown
Author

@shralex Hi I just rebased the PR. Let me know if I need to change anything :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants